# Source code for hysop.core.mpi.redistribute

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Implementation for data transfer/redistribution between topologies

.. currentmodule:: hysop.core.mpi.redistribute

See hysop.operator.redistribute.Redistribute for automatic
redistribute deployment.

* :class:`~RedistributeOperatorBase` abstract base class
* :class:`~RedistributeIntra` for topologies/operators defined
  inside the same mpi communicator
* :class:`~RedistributeInter` for topologies/operators defined
  on two different mpi communicators
* :class:`~RedistributeOverlap` for topologies defined
  inside the same mpi parent communicator and
  with a different number of processes
"""

from hashlib import sha1
import numpy as np
from hysop.constants import Backend, DirectionLabels, MemoryOrdering
from hysop.tools.htypes import check_instance, to_set, first_not_None
from hysop.tools.decorators import debug
from hysop.tools.numpywrappers import npw, slices_empty
from hysop.tools.mpi_utils import get_mpi_order
from hysop.topology.cartesian_topology import Topology, CartesianTopology, TopologyView
from hysop.topology.topology_descriptor import TopologyDescriptor
from hysop.core.mpi.topo_tools import TopoTools
from hysop.core.mpi.bridge import Bridge, BridgeOverlap, BridgeInter
from hysop.operator.base.redistribute_operator import RedistributeOperatorBase
from hysop.core.graph.computational_operator import ComputationalGraphOperator
from hysop.core.graph.graph import op_apply
from hysop import MPI, MPIParams
from hysop.parameters.scalar_parameter import ScalarParameter, TensorParameter

# Debug switch for this module: any non-zero value enables the verbose
# debug output guarded by DEBUG_REDISTRIBUTE in the redistribute operators.
DEBUG_REDISTRIBUTE = 0


def _memcpy(dst, src, target_indices, source_indices, skind=None, tkind=None):
    """Copy ``src[source_indices]`` into ``dst[target_indices]``.

    A host to host copy is a plain sliced assignment. Every copy involving
    an OpenCL buffer is delegated to an ``OpenClCopyBufferRectLauncher``
    built from the given slices.

    Parameters
    ----------
    dst, src:
        Destination and source arrays; they must share the same dtype.
    target_indices, source_indices:
        Indices of the region written in ``dst`` and read from ``src``.
    skind, tkind:
        Backend kind of the source and of the target. When None, the kind is
        read from ``src.backend.kind`` (respectively ``dst.backend.kind``).

    Returns
    -------
    The object returned by the copy kernel launcher when an OpenCL buffer is
    involved, or None for a host to host copy.

    Raises
    ------
    RuntimeError
        If the (source kind, target kind) combination is not handled.
    """
    assert src.dtype == dst.dtype
    if skind is None:
        skind = src.backend.kind
    if tkind is None:
        tkind = dst.backend.kind

    def _unsupported():
        raise RuntimeError(
            "Copy from {} to {} are not handled yet.".format(
                src.__class__, dst.__class__
            )
        )

    def _rect_copy(launcher_cls, queue_owner):
        # Build the rectangular copy kernel and enqueue it on the default
        # queue of the array that lives on the OpenCL side.
        launcher = launcher_cls.from_slices(
            varname="redistribute",
            src=src,
            dst=dst,
            src_slices=source_indices,
            dst_slices=target_indices,
        )
        return launcher(queue=queue_owner.default_queue)

    if skind == Backend.HOST:
        if tkind == Backend.HOST:
            dst[target_indices] = src[source_indices]
            return None
        if tkind == Backend.OPENCL:
            from hysop.backend.device.opencl.opencl_copy_kernel_launchers import (
                OpenClCopyBufferRectLauncher,
            )

            return _rect_copy(OpenClCopyBufferRectLauncher, dst)
        _unsupported()

    if skind == Backend.OPENCL:
        # The OpenCL launcher is imported lazily, only when a device buffer
        # is actually involved.
        from hysop.backend.device.opencl.opencl_copy_kernel_launchers import (
            OpenClCopyBufferRectLauncher,
        )

        if tkind == Backend.HOST:
            return _rect_copy(OpenClCopyBufferRectLauncher, src)
        if tkind == Backend.OPENCL:
            # Device to device copies are only handled inside one environment.
            assert src.backend.cl_env is dst.backend.cl_env
            return _rect_copy(OpenClCopyBufferRectLauncher, src)

    _unsupported()


class RedistributeIntra(RedistributeOperatorBase):
    """Data transfer between two operators/topologies.

    Source and target must:
      * be CartesianTopology topologies with the same global resolution
      * be defined on the same communicator
      * work on the same number of mpi processes
    """

    @classmethod
    def can_redistribute(cls, source_topo, target_topo, **kwds):
        """Return True if this operator can transfer data from source_topo
        to target_topo, False otherwise."""
        # Both ends have to be CartesianTopology instances.
        if not isinstance(source_topo, CartesianTopology):
            return False
        if not isinstance(target_topo, CartesianTopology):
            return False

        # Both ends have to share the same global grid resolution.
        resolutions_match = npw.array_equal(
            source_topo.mesh.grid_resolution, target_topo.mesh.grid_resolution
        )
        if not resolutions_match:
            return False

        # Parent communicators have to be compatible
        # (same communicator, same number of mpi processes).
        comms_match = TopoTools.compare_comm(source_topo.parent, target_topo.parent)
        return True if comms_match else False

    def __new__(cls, **kwds):
        return super().__new__(cls, **kwds)

    def __init__(self, **kwds):
        """Data transfer between two operators/topologies defined on the
        same communicator.

        Source and target must:
          * be defined on the same communicator
          * work on the same number of mpi processes
          * work with the same global resolution

        All keyword arguments are forwarded to the base class.
        """
        super().__init__(**kwds)

        # Note kept from the original authors:
        # comm from io_params will be used as the reference for all mpi
        # communication of this operator.
        # --> rank computed in refcomm
        # --> source and target must work inside refcomm
        # If io_params is None, refcomm will be COMM_WORLD.

    @debug
    def discretize(self):
        """Build the bridge between the source and target topologies and
        cache the mpi derived types used by apply().

        Raises
        ------
        RuntimeError
            If source and target topology states differ in dimension, axes
            or memory order.
        """
        super().discretize()

        var = self.variable
        # Discrete fields to be sent and received, keyed by variable.
        self._vsource = {var: self.input_discrete_fields[var]}
        self._vtarget = {var: self.output_discrete_fields[var]}

        ifield = self.input_discrete_fields[self.variable]
        ofield = self.output_discrete_fields[self.variable]
        source_topo = ifield.topology
        target_topo = ofield.topology

        sstate = source_topo.topology_state
        tstate = target_topo.topology_state
        if (
            (sstate.dim != tstate.dim)
            or (sstate.axes != tstate.axes)
            or (sstate.memory_order != tstate.memory_order)
        ):
            raise RuntimeError(
                "".join(
                    (
                        "Topology state mismatch between source and target.",
                        "\nSource topology state:",
                        str(sstate),
                        "\nTarget topology state:",
                        str(tstate),
                    )
                )
            )

        assert all(source_topo.mesh.local_resolution == ifield.resolution)
        assert all(target_topo.mesh.local_resolution == ofield.resolution)

        # The topologies are known: the bridge can be created.
        self.bridge = Bridge(
            source_topo, target_topo, self.dtype, get_mpi_order(ifield.sdata)
        )
        self._rank = self.bridge._rank

        # rank -> mpi derived type, for send operations
        self._send = self.bridge.send_types()
        # rank -> mpi derived type, for receive operations
        self._receive = self.bridge.recv_types()
        self._has_requests = False

        self.dFin = ifield
        self.dFout = ofield

    @op_apply
    def apply(self, **kwds):
        """Transfer the variable from the source field to the target field.

        Local data are copied with _memcpy, remote data use one non-blocking
        receive / synchronous send per remote rank. All requests are
        completed before the ghosts of the target field are exchanged.
        """
        dFin, dFout = self.dFin, self.dFout
        super().apply(**kwds)

        bridge = self.bridge
        # Pending requests, keyed by '<variable name><remote rank>'.
        self._r_request = {}
        self._s_request = {}

        my_tag = self.mpi_params.rank + 1
        # Communicator used for send/receive operations, taken from the bridge.
        refcomm = self.bridge.comm
        var = self.variable
        var_name = var.name
        copy_events = []

        # Data that stay on this process are copied directly.
        if bridge.has_local_inter():
            dst = self._vtarget[var].sdata
            src = self._vsource[var].sdata
            axes = self._vtarget[var].topology_state.axes
            source_indices = bridge.local_source_ind()
            target_indices = bridge.local_target_ind()
            evt = _memcpy(dst, src, target_indices, source_indices)
            if evt is not None:
                copy_events.append(evt)

        # Post the receives from the other mpi processes.
        for rk, mpi_type in self._receive.items():
            if rk == self._rank:
                continue
            recvtag = my_tag * 989 + (rk + 1) * 99
            dst = self._vtarget[var].sdata
            assert dst.backend.kind is Backend.HOST
            self._r_request[var_name + str(rk)] = refcomm.Irecv(
                [dst.handle, 1, mpi_type], source=rk, tag=recvtag
            )
            self._has_requests = True

        # Post the sends to the other mpi processes.
        for rk, mpi_type in self._send.items():
            if rk == self._rank:
                continue
            sendtag = (rk + 1) * 989 + my_tag * 99
            src = self._vsource[var].sdata
            assert src.backend.kind is Backend.HOST
            self._s_request[var_name + str(rk)] = refcomm.Issend(
                [src.handle, 1, mpi_type], dest=rk, tag=sendtag
            )
            self._has_requests = True

        # Complete local copies, then every posted request.
        for evt in copy_events:
            evt.wait()

        if self._has_requests:
            for request in self._r_request.values():
                request.Wait()
            for request in self._s_request.values():
                request.Wait()
        self._has_requests = False

        if DEBUG_REDISTRIBUTE:
            print("resolution, compute_resolution, ghosts, compute_slices")
            for fld in (dFin, dFout):
                print(
                    fld.resolution,
                    fld.compute_resolution,
                    fld.ghosts,
                    fld.compute_slices,
                )
            print()
            print("BEFORE")
            dFout.print_with_ghosts()

        dFout.exchange_ghosts()

        if DEBUG_REDISTRIBUTE:
            print("AFTER")
            dFout.print_with_ghosts()

            def _reduced_mean(fld):
                total = refcomm.allreduce(fld.sdata[fld.compute_slices].sum().get())
                return total / float(refcomm.size)

            # Input and output must carry the same amount of data.
            mean_in = _reduced_mean(dFin)
            mean_out = _reduced_mean(dFout)
            assert npw.isclose(mean_in, mean_out), f"{mean_in} != {mean_out}"
class RedistributeInter(RedistributeOperatorBase):
    """Data transfer between two operators/topologies.

    Source and target must:
      * be CartesianTopology topologies with the same global resolution
      * be defined on different communicators
    """

    @classmethod
    def can_redistribute(cls, source_topo, target_topo, other_task_id=None, **kwds):
        """Return True if this operator can transfer data from source_topo
        to target_topo, False otherwise.

        When only one topology is known locally, the global resolutions are
        exchanged between the root ranks of the two tasks (blocking
        send/recv on the domain parent communicator) and then broadcast
        inside the local task.

        Parameters
        ----------
        source_topo, target_topo:
            Source and target topologies; the one that does not live on the
            current task is expected not to be a CartesianTopology.
        other_task_id:
            Identifier of the remote task, used to find its root rank.
        """
        tin = source_topo
        tout = target_topo
        tin_is_cart = isinstance(tin, CartesianTopology)
        tout_is_cart = isinstance(tout, CartesianTopology)

        # Source and target have to be defined on different tasks:
        # (one topology is missing) or (two topologies on different tasks).
        if tin_is_cart == tout_is_cart:
            # A missing topology cannot be compared: reject instead of
            # failing on an attribute access on None.
            if tin is None or tout is None:
                return False
            if tout.mpi_params.task_id == tin.mpi_params.task_id:
                return False

        if tout_is_cart and not tin_is_cart:
            _is_source, _is_dest = False, True
            domain = tout.domain
            other_resol = npw.zeros_like(tout.mesh.grid_resolution)
            my_resol = tout.mesh.grid_resolution
        elif tin_is_cart and not tout_is_cart:
            _is_source, _is_dest = True, False
            domain = tin.domain
            other_resol = npw.zeros_like(tin.mesh.grid_resolution)
            my_resol = tin.mesh.grid_resolution
        elif tin_is_cart and tout_is_cart:
            _is_source, _is_dest = True, True
            domain = tin.domain
            other_resol = tin.mesh.grid_resolution
            my_resol = tout.mesh.grid_resolution
        else:
            return False

        # Source and target must have the same global resolution.
        # Root ranks exchange their resolution: the source sends first and
        # the destination receives first so that the blocking calls match.
        if domain.task_rank() == 0:
            if _is_source and not _is_dest:
                domain.parent_comm.send(
                    tin.mesh.grid_resolution,
                    dest=domain.task_root_in_parent(other_task_id),
                )
                other_resol = domain.parent_comm.recv(
                    source=domain.task_root_in_parent(other_task_id)
                )
            if _is_dest and not _is_source:
                other_resol = domain.parent_comm.recv(
                    source=domain.task_root_in_parent(other_task_id)
                )
                domain.parent_comm.send(
                    tout.mesh.grid_resolution,
                    dest=domain.task_root_in_parent(other_task_id),
                )
        # Collective call: executed by every rank of the task.
        other_resol = domain.task_comm.bcast(other_resol, root=0)
        if not npw.array_equal(my_resol, other_resol):
            return False
        return True

    def __new__(cls, other_task_id=None, **kwds):
        return super().__new__(cls, **kwds)

    def __init__(self, mpi_params=None, other_task_id=None, **kwds):
        """Data transfer between two operators/topologies.

        Source and target must:
          * be CartesianTopology topologies with the same global resolution
          * be defined on different communicators

        Parameters
        ----------
        mpi_params:
            Forwarded to the base class (or stored as is for a fake init).
        other_task_id:
            Identifier of the task at the other end of the transfer.
        kwds:
            Base class arguments. 'variable' is mandatory; 'source_topo' and
            'target_topo' are read when the variable is not None.

        Notes
        -----
        When kwds['variable'] is None only a placeholder ("fake init") is
        set up: the base class is not initialised and the operator is
        expected to be initialised again later.
        """
        if kwds["variable"] is not None:
            self.fake_init = False
            # Base class initialisation
            super().__init__(mpi_params=mpi_params, **kwds)
            self._other_task_id = other_task_id
            self._synchronize(kwds["source_topo"], kwds["target_topo"])
        else:
            # Fake init. Should be called again later
            self.fake_init = True
            self.initialized = True
            self.name = "TempName"
            self.pretty_name = "TempName"
            self.mpi_params = mpi_params
            self._input_fields_to_dump = []
            self._output_fields_to_dump = []
            self.input_fields = {}
            self.output_fields = {}
            self.input_params = {}
            self.output_params = {}

    def _synchronize(self, tin, tout):
        """Ensure that the two redistributes are operating on the same variable.

        The variable name is exchanged between the root ranks of the two
        tasks, then broadcast on the local communicators.
        """
        v = self.variable
        in_name = "" if tin is None else v.name
        out_name = "" if tout is None else v.name
        local_topo = first_not_None((tin, tout))
        domain = local_topo.domain

        # Exchange names on root ranks first
        if domain.task_rank() == 0 and in_name != out_name:
            rcv_name = domain.parent_comm.sendrecv(
                v.name,
                sendtag=self._other_task_id,
                recvtag=local_topo.mpi_params.task_id,
                dest=domain.task_root_in_parent(self._other_task_id),
                source=domain.task_root_in_parent(self._other_task_id),
            )
            in_name, out_name = (
                rcv_name if name == "" else name for name in (in_name, out_name)
            )
        # then broadcast other's names on local ranks
        if tout is not None:
            in_name = tout.mpi_params.comm.bcast(in_name, root=0)
        if tin is not None:
            out_name = tin.mpi_params.comm.bcast(out_name, root=0)
        assert in_name == out_name and in_name == v.name

    def get_preserved_input_fields(self):
        """This Inter-communicator redistribute is preserving the output fields.

        Output fields are invalidated on other topologies only if the field
        is not also an input: return the set of fields that have a non-None
        entry both in output_fields and in input_fields.
        """
        o_f, i_f = self.output_fields, self.input_fields
        return {
            f
            for f, otopo in o_f.items()
            if otopo is not None and i_f.get(f) is not None
        }

    def output_topology_state(self, output_field, input_topology_states):
        """Return a copy of the topology state of the output field.

        Every given input topology state is first compared (dimension, axes
        and memory order) with the state of the output topology.

        Parameters
        ----------
        output_field: Field
            One of the output fields of this operator.
        input_topology_states: dict
            Field -> TopologyState. Either empty or holding exactly the
            input fields of this operator.

        Raises
        ------
        RuntimeError
            If an input state does not match the output topology state.
        """
        from hysop.fields.continuous_field import Field
        from hysop.topology.topology import TopologyState

        check_instance(output_field, Field)
        check_instance(input_topology_states, dict, keys=Field, values=TopologyState)

        assert output_field in self.output_fields.keys()
        assert len(set(input_topology_states.keys())) == 0 or set(
            input_topology_states.keys()
        ) == set(self.input_fields.keys())

        # The reference state is bound unconditionally so that an empty
        # input_topology_states (allowed by the assert above) is handled.
        ref_state = self.output_fields[output_field].topology_state

        if input_topology_states:
            # Only used to build the error message below.
            ref_field, _ = next(iter(input_topology_states.items()))
            ref_topo = self.input_fields[ref_field]

        for ifield, istate in input_topology_states.items():
            itopo = self.input_fields[ifield]
            if not (
                istate.dim == ref_state.dim
                and istate.axes == ref_state.axes
                and istate.memory_order == ref_state.memory_order
            ):
                msg = "\nInput topology state for field {} defined on topology {} does "
                msg += "not match reference input topology state {} defined on topology {} "
                msg += "for operator {}.\n"
                msg += (
                    "ComputationalGraphOperator default behaviour is to raise an error "
                )
                msg += "when all input states do not match exactly.\n\n"
                msg += "Reference state: {}\n"
                msg += "Offending state: {}\n\n"
                msg += "This behaviour can be changed by overriding output_topology_state() for "
                msg += "your custom operator needs."
                msg = msg.format(
                    ifield.name,
                    itopo.tag,
                    ref_field.name,
                    ref_topo.tag,
                    self.name,
                    ref_state,
                    istate,
                )
                raise RuntimeError(msg)
        return ref_state.copy()

    @debug
    def get_field_requirements(self):
        """Return the field requirements of this operator.

        A requirement missing on one side (input or output) is filled from
        the requirement of the same field on the other side, and C-contiguous
        memory ordering is enforced on every requirement.
        """
        reqs = super().get_field_requirements()

        for f in self.input_fields:
            try:
                _ = reqs.get_input_requirement(f)
            except RuntimeError:
                reqs.update_inputs({f: reqs.get_output_requirement(f)[1]})

        for f in self.output_fields:
            try:
                _ = reqs.get_output_requirement(f)
            except RuntimeError:
                reqs.update_outputs({f: reqs.get_input_requirement(f)[1]})

        # Note: We enforce here the C-order to simplify the communication. As most part
        # of HySoP is in any- or c-order, this is not a big overhead (if field in
        # F-order, memory reordering is likely to be already present)
        for is_input, requirements in reqs.iter_requirements():
            if requirements is None:
                continue
            (field, td, req) = requirements
            req.memory_order = MemoryOrdering.C_CONTIGUOUS
        return reqs

    @debug
    def _check_inout_topology_states(
        self, ifields, itopology_states, ofields, otopology_states
    ):
        """Raise a RuntimeError when both the input and the output topology
        states do not match the fields of the operator.

        The set differences used below suggest that the four arguments are
        sets of fields (TODO confirm against callers).
        """
        if (ifields != itopology_states) and (ofields != otopology_states):
            # Each fragment is formatted on its own: formatting the whole
            # accumulated message again would choke on the braces of the
            # set representations already inserted.
            msg = "\nFATAL ERROR: {}::{}.handle_topologies()\n\n".format(
                type(self).__name__, self.name
            )
            if not ((ifields != itopology_states) and (ofields == otopology_states)):
                msg += "input_topology_states fields did not match operator's input Fields.\n"
                if ifields - itopology_states:
                    msg += "input_topology_states are missing the following Fields: {}\n".format(
                        ifields - itopology_states
                    )
                else:
                    msg += "input_topology_states is providing useless extra Fields: {}\n".format(
                        itopology_states - ifields
                    )
            if not ((ofields != otopology_states) and (ifields == itopology_states)):
                msg += "output_topology_states fields did not match operator's output Fields.\n"
                if ofields - otopology_states:
                    msg += "output_topology_states are missing the following Fields: {}\n".format(
                        ofields - otopology_states
                    )
                else:
                    msg += "output_topology_states is providing useless extra Fields: {}\n".format(
                        otopology_states - ofields
                    )
            raise RuntimeError(msg)

    @debug
    def _check_variables(self):
        """
        Check input and output variables.
        Called automatically in ComputationalGraphNode.check()

        A TypeError raised by the base class check is tolerated when, for
        each (input, output) pair, one side is None and the other one is a
        TopologyView.
        """
        try:
            super()._check_variables()
        except TypeError:
            for (ik, iv), (ok, ov) in zip(
                self.input_fields.items(), self.output_fields.items()
            ):
                if not (
                    (iv is None and isinstance(ov, TopologyView))
                    or (ov is None and isinstance(iv, TopologyView))
                ):
                    if iv is None:
                        msg = "Expected a Topology instance because input topo is None but got a {}.".format(
                            ov.__class__
                        )
                        msg += "\nAll topologies are expected to be set after "
                        msg += "ComputationalGraph.get_field_requirements() has been called."
                        raise TypeError(msg)
                    if ov is None:
                        msg = "Expected a Topology instance because output topo is None but got a {}.".format(
                            iv.__class__
                        )
                        msg += "\nAll topologies are expected to be set after "
                        msg += "ComputationalGraph.get_field_requirements() has been called."
                        raise TypeError(msg)

    def discretize(self):
        """Build the bridge between the two tasks and prepare the host
        buffers used by apply().

        Raises
        ------
        RuntimeError
            If source and target topology states (dimension, axes, memory
            order) differ.
        """
        super().discretize()

        # we can create the bridge
        ifield, ofield = None, None
        if self.variable in self.input_discrete_fields:
            ifield = self.input_discrete_fields[self.variable]
        if self.variable in self.output_discrete_fields:
            ofield = self.output_discrete_fields[self.variable]

        _is_source, _is_target = False, False
        source_topo, target_topo, source_id = None, None, None
        source_tstate, target_tstate = None, None

        if ifield is not None:
            _is_source = True
            source_topo = ifield.topology
            source_id = source_topo.mpi_params.task_id
            target_id = self._other_task_id
            source_tstate = (
                source_topo.topology_state.dim,
                source_topo.topology_state.axes,
                source_topo.topology_state.memory_order,
            )
            if DEBUG_REDISTRIBUTE != 0:
                print(
                    "This is a redistribute of {} from source topology {}".format(
                        self.variable.name, source_topo.tag
                    )
                )

        if ofield is not None:
            _is_target = True
            target_topo = ofield.topology
            target_id = target_topo.mpi_params.task_id
            source_id = self._other_task_id if source_id is None else source_id
            target_tstate = (
                target_topo.topology_state.dim,
                target_topo.topology_state.axes,
                target_topo.topology_state.memory_order,
            )
            if DEBUG_REDISTRIBUTE != 0:
                print(
                    "This is a redistribute of {} to target topology {}".format(
                        self.variable.name, target_topo.tag
                    )
                )

        self._synchronize(source_topo, target_topo)

        local_topo = first_not_None((source_topo, target_topo))
        domain = local_topo.domain
        self._source_id, self._target_id = source_id, target_id

        # compute a tag from algebraic relation :
        # x,y \in [0;ss-1] and Tag = y+ss*(x+ss*(HASH/ss^2)
        # Therefore y=Tag%ss and x = (Tag/ss)%ss
        # NOTE(review): '/' is a true division here, so basetag is a float
        # close to h/100 and not a multiple of ss*ss. Tags stay distinct for
        # distinct (x, y) but the modulo relation above does not hold;
        # confirm whether a floor division was intended.
        ss = domain.parent_comm.Get_size() + 1
        h = int(
            sha1((self.variable.name + "RedistributeInter").encode()).hexdigest(), 16
        ) % (1 << 31)
        basetag = ss * ss * (npw.uint32(h) / (100 * ss * ss))
        self._tag = lambda x, y: npw.uint32(basetag + x * ss + y)

        # Exchange on root ranks first ...
        if domain.task_rank() == 0 and source_tstate != target_tstate:
            rcv_tstate = domain.parent_comm.sendrecv(
                first_not_None((source_tstate, target_tstate)),
                sendtag=self._other_task_id,
                recvtag=local_topo.mpi_params.task_id,
                dest=domain.task_root_in_parent(self._other_task_id),
                source=domain.task_root_in_parent(self._other_task_id),
            )
            source_tstate, target_tstate = (
                rcv_tstate if state is None else state
                for state in (source_tstate, target_tstate)
            )
        # ... then broadcast
        if _is_source:
            target_tstate = source_topo.mpi_params.comm.bcast(target_tstate, root=0)
        if _is_target:
            source_tstate = target_topo.mpi_params.comm.bcast(source_tstate, root=0)

        if not (source_tstate == target_tstate):
            msg = "Topology state mismatch between source and target."
            msg += "\nSource topology state:"
            msg += str(source_tstate)
            msg += "\nTarget topology state:"
            msg += str(target_tstate)
            raise RuntimeError(msg)

        if _is_source:
            assert all(source_topo.mesh.local_resolution == ifield.resolution)
        if _is_target:
            assert all(target_topo.mesh.local_resolution == ofield.resolution)

        # Create bridges and store comm types and indices
        if domain.tasks_overlapping(source_id, target_id) is not None:
            self.bridge = BridgeOverlap(
                source=source_topo,
                target=target_topo,
                source_id=source_id,
                target_id=target_id,
                dtype=self.dtype,
                order=get_mpi_order(first_not_None((ifield, ofield)).sdata),
            )
        else:
            self.bridge = BridgeInter(
                current=local_topo,
                source_id=source_id,
                target_id=target_id,
                dtype=self.dtype,
                order=get_mpi_order(first_not_None((ifield, ofield)).sdata),
            )

        # task id -> (rank -> derived type / indices) needed for the transfer
        # (send on the source task, receive on the target task)
        self._comm_types, self._comm_indices = {}, {}
        if _is_source:
            self._comm_types[source_id] = self.bridge.transfer_types(task_id=source_id)
            self._comm_indices[source_id] = self.bridge.transfer_indices(
                task_id=source_id
            )
        if _is_target:
            self._comm_types[target_id] = self.bridge.transfer_types(task_id=target_id)
            self._comm_indices[target_id] = self.bridge.transfer_indices(
                task_id=target_id
            )
        self._has_requests = False

        if DEBUG_REDISTRIBUTE != 0:
            print("RedistributeInter communication indices", self._comm_indices)
            print("RedistributeInter communication types", self._comm_types)

        self.dFin = ifield
        self.dFout = ofield

        # MPI transfers work on host memory: fields that do not live on the
        # host backend go through a zero-initialised host staging buffer.
        self._need_copy_before, self._need_copy_after = False, False
        if ifield is not None:
            if ifield.backend.kind is not Backend.HOST:
                self._need_copy_before = True
                self._dFin_data = ifield.backend.host_array_backend.empty_like(
                    ifield.buffers[0]
                ).handle
                self._dFin_data[...] = 0.0
            else:
                self._dFin_data = ifield.sdata.handle
        if ofield is not None:
            if ofield.backend.kind is not Backend.HOST:
                self._need_copy_after = True
                self._dFout_data = ofield.backend.host_array_backend.empty_like(
                    ofield.buffers[0]
                ).handle
                self._dFout_data[...] = 0.0
            else:
                self._dFout_data = ofield.sdata.handle

        self._is_source = _is_source
        self._is_target = _is_target

    @op_apply
    def apply(self, **kwds):
        """Transfer the variable between the source and the target task.

        On the source side one non-blocking send is posted per remote rank;
        on the target side blocking receives are performed and the ghosts of
        the output field are exchanged. Every posted send is completed
        before returning so that the send buffer can safely be reused.
        """
        comm = self.bridge.comm
        rank = comm.Get_rank()
        types = self._comm_types
        indices = self._comm_indices

        # TODO : Using GPU-aware MPI would simplify the usage of _memcpy
        send_requests = []
        if self._is_source:
            for rk, t in types[self._source_id].items():
                if self._need_copy_before:
                    evt = _memcpy(
                        self._dFin_data,
                        self.dFin.sdata,
                        target_indices=indices[self._source_id][rk],
                        source_indices=indices[self._source_id][rk],
                        skind=Backend.OPENCL,
                        tkind=Backend.HOST,
                    )
                    # The host staging buffer must be filled before it is
                    # handed over to MPI.
                    if evt is not None:
                        evt.wait()
                sendtag = self._tag(rk + 1, rank + 1)
                send_requests.append(
                    comm.Isend([self._dFin_data, 1, t], dest=rk, tag=sendtag)
                )

        if self._is_target:
            copy_events = []
            for rk, t in types[self._target_id].items():
                recvtag = self._tag(rank + 1, rk + 1)
                comm.Recv([self._dFout_data, 1, t], source=rk, tag=recvtag)
                if self._need_copy_after:
                    evt = _memcpy(
                        self.dFout.sdata,
                        self._dFout_data,
                        target_indices=indices[self._target_id][rk],
                        source_indices=indices[self._target_id][rk],
                        skind=Backend.HOST,
                        tkind=Backend.OPENCL,
                    )
                    if evt is not None:
                        copy_events.append(evt)
            # Received data must be in place before ghosts are exchanged.
            for evt in copy_events:
                evt.wait()
            self.dFout.exchange_ghosts()

        # Sends are completed after the receives: when the tasks overlap,
        # the matching receives are then guaranteed to have been posted.
        for request in send_requests:
            request.Wait()
class RedistributeInterParam(ComputationalGraphOperator):
    """Parameter transfer between two operators/topologies.

    Source and target must:
      * be MPIParams defined on different communicators
    """

    @classmethod
    def supports_mpi(cls):
        """This operator supports mpi."""
        return True

    def __new__(
        cls, parameter, source_topo, target_topo, other_task_id, domain, **kwds
    ):
        return super().__new__(cls, **kwds)

    def __init__(
        self, parameter, source_topo, target_topo, other_task_id, domain, **kwds
    ):
        """Communicate parameters through tasks.

        Parameters
        ----------
        parameter: tuple of ScalarParameter or TensorParameter
            Parameters to communicate.
        source_topo: MPIParams or None
            MPI parameters of the source; None when the source is the
            other task.
        target_topo: MPIParams or None
            MPI parameters of the target; None when the target is the
            other task.
        other_task_id:
            Identifier of the task at the other end of the transfer.
        domain:
            Domain used to query tasks and to build the communicator shared
            by the two tasks.
        kwds:
            Extra keyword arguments forwarded to the base class.
        """
        check_instance(parameter, tuple, values=(ScalarParameter, TensorParameter))
        check_instance(source_topo, MPIParams, allow_none=True)
        check_instance(target_topo, MPIParams, allow_none=True)
        assert not (source_topo is None and target_topo is None)

        # Parameters are inputs on the source task, outputs on the target task.
        input_params, output_params = {}, {}
        if source_topo is not None and source_topo.on_task:
            input_params = dict.fromkeys(parameter, source_topo)
        if target_topo is not None and target_topo.on_task:
            output_params = dict.fromkeys(parameter, target_topo)

        super().__init__(
            mpi_params=first_not_None(source_topo, target_topo),
            input_params=input_params,
            output_params=output_params,
            input_fields={},
            output_fields={},
            **kwds,
        )
        self.initialized = True
        self.domain = domain

        if source_topo is None:
            self.source_task = other_task_id
        else:
            self.source_task = source_topo.task_id
        if target_topo is None:
            self.target_task = other_task_id
        else:
            self.target_task = target_topo.task_id

        self.task_is_source = domain.is_on_task(self.source_task)
        self.task_is_target = domain.is_on_task(self.target_task)
        if self.task_is_source:
            assert source_topo.on_task
        if self.task_is_target:
            assert target_topo.on_task

        remote_task = self.target_task if self.task_is_source else self.source_task
        self.inter_comm = domain.task_intercomm(remote_task)
        if self.inter_comm.is_inter:
            # Disjoint tasks with real inter-communicator
            self._the_apply = self._apply_intercomm
        elif self.inter_comm.is_intra:
            # Overlapping tasks using an intra-communicator from union of tasks procs
            self._the_apply = self._apply_intracomm

        # Parameters grouped by dtype, sorted by name inside each group.
        self._all_params_by_type = {}
        for param in sorted(self.parameters, key=lambda prm: prm.name):
            self._all_params_by_type.setdefault(param.dtype, []).append(param)

        def _buffers_by_type():
            # One 1D buffer per dtype, with one entry per parameter.
            return {
                dtype: np.zeros((len(params),), dtype=dtype)
                for dtype, params in self._all_params_by_type.items()
            }

        self._send_temp_by_type = _buffers_by_type()
        self._recv_temp_by_type = _buffers_by_type()

    @op_apply
    def apply(self, **kwds):
        """Run the transfer selected at construction time."""
        self._the_apply(**kwds)

    def _apply_intercomm(self, **kwds):
        """Disjoint tasks so inter-comm bcast is needed."""
        for dtype, params in self._all_params_by_type.items():
            if self.task_is_source:
                self._send_temp_by_type[dtype][...] = [prm() for prm in params]
                if self.domain.task_rank() == 0:
                    root = MPI.ROOT
                else:
                    root = MPI.PROC_NULL
                self.inter_comm.bcast(self._send_temp_by_type[dtype], root=root)
            if self.task_is_target:
                self._recv_temp_by_type[dtype] = self.inter_comm.bcast(
                    self._send_temp_by_type[dtype], root=0
                )
                for prm, value in zip(params, self._recv_temp_by_type[dtype]):
                    prm.value = value

    def _apply_intracomm(self, **kwds):
        """Communicator is an intra-communicator defined as tasks' comm union.
        Single broadcast is enough.
        """
        for dtype, params in self._all_params_by_type.items():
            if self.task_is_source and self.domain.task_rank() == 0:
                self._send_temp_by_type[dtype][...] = [prm() for prm in params]
            self._recv_temp_by_type[dtype] = self.inter_comm.bcast(
                self._send_temp_by_type[dtype],
                self.domain.task_root_in_parent(self.source_task),
            )
            if self.task_is_target:
                for prm, value in zip(params, self._recv_temp_by_type[dtype]):
                    prm.value = value